{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #C8E6C9; padding: 10px; color: #1b7678\">\n",
    "<b>Pre-requisites</b>: Basic knowledge of Machine Learning and Tabular Problems like Regression and Classification <br></br>\n",
    "<b>Level</b>: Beginner\n",
    "</div>\n",
    "\n",
    "In this tutorial, we will look at how to tackle any tabular machine learning problem (classification or regression) using PyTorch Tabular. We will use the Covertype dataset from the UCI repository. The dataset contains 581012 rows and 54 columns. The dataset is a multi-class classification problem. The goal is to predict the forest cover type from cartographic variables only (no remotely sensed data).\n",
    "\n",
    "In a typical machine learning workflow, we would do the following steps:\n",
    "1. Load the dataset  \n",
    "2. Analyze the dataset  \n",
    "3. Split the dataset into train and test  \n",
    "4. Preprocess the dataset  \n",
    "5. Define the model  \n",
    "6. Train the model  \n",
    "7. Make predictions on new data  \n",
    "8. Evaluate the model  \n",
    "\n",
    "Let's see how we do the same using PyTorch Tabular"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Step 1: Load the Data\n",
    "\n",
    "### Cover Type Dataset\n",
    "\n",
    "Predicting forest cover type from cartographic variables only (no remotely sensed data). The actual forest cover type for a given observation (30 x 30 meter cell) was determined from US Forest Service (USFS) Region 2 Resource Information System (RIS) data. Independent variables were derived from data originally obtained from US Geological Survey (USGS) and USFS data. Data is in raw form (not scaled) and contains binary (0 or 1) columns of data for qualitative independent variables (wilderness areas and soil types).\n",
    "\n",
    "This study area includes four wilderness areas located in the Roosevelt National Forest of northern Colorado. These areas represent forests with minimal human-caused disturbances, so that existing forest cover types are more a result of ecological processes rather than forest management practices.\n",
    "\n",
    "There is a simple utility method in `PyTorch Tabular` to load this particular dataset. It downloads the data from [UCI ML Repository](https://archive.ics.uci.edu/ml/datasets/covertype). The original dataset has two categorical information - Soil Type and Wilderness Area - but one-hot encoded. The utility method converts them to categorical columns to make it more closer to real-life datasets in the wild."
   ]
  },
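  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To picture that conversion, here is a minimal pandas sketch on a hypothetical toy frame. It is not the library's own code; it assumes the one-hot columns of a group share a common prefix (for example `Wilderness_Area1`, `Wilderness_Area2`) and collapses each group back into a single label column with `DataFrame.idxmax`.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical one-hot encoded columns sharing the prefix 'Wilderness_Area'\n",
    "onehot = pd.DataFrame({\n",
    "    'Wilderness_Area1': [1, 0, 0],\n",
    "    'Wilderness_Area2': [0, 1, 0],\n",
    "    'Wilderness_Area3': [0, 0, 1],\n",
    "})\n",
    "\n",
    "def collapse_one_hot(df, prefix):\n",
    "    # Columns that belong to this one-hot group\n",
    "    cols = [c for c in df.columns if c.startswith(prefix)]\n",
    "    # For each row, idxmax returns the name of the column holding the 1\n",
    "    labels = df[cols].idxmax(axis=1)\n",
    "    out = df.drop(columns=cols)\n",
    "    out[prefix] = labels\n",
    "    return out\n",
    "\n",
    "collapsed = collapse_one_hot(onehot, 'Wilderness_Area')\n",
    "print(collapsed['Wilderness_Area'].tolist())\n",
    "# ['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3']\n",
    "```\n",
    "\n",
    "Matching on `startswith` would also catch unrelated columns that happen to share the prefix, so a real implementation would need a stricter column-name pattern."
   ]
  },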
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular.utils import load_covertype_dataset\n",
    "data, _, _, _ = load_covertype_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Analyze the dataset\n",
    "\n",
    "In this step, we will explore the data to understand the data better. Exploratory Data Analysis (EDA) can be many things and it depends on the data and the problem we are trying to solve. And this can help us understand the data better and make some decisions on how to proceed with the data. But here, we will restrict ourselves to the most basic data analysis; just enough to understand which are the continuous and categorical columns, and if there are any missing values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Index</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008000; text-decoration-color: #008000\">'Wilderness_Area'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'Soil_Type'</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'object'</span><span style=\"font-weight: bold\">)</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;35mIndex\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'Wilderness_Area'\u001b[0m, \u001b[32m'Soil_Type'\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[32m'object'\u001b[0m\u001b[1m)\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from rich import print\n",
    "# One of the easiest ways to identify categorical features is using the pandas select_dtypes function.\n",
    "categorical_features = data.select_dtypes(include=['object'])\n",
    "print(categorical_features.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But this may not be always reliable. For example, if we have a column called `month` and it has values from 1 to 12, then it is a categorical column. But `select_dtypes` will treat it as a continuous column. So, we need to be careful and use our judgement."
   ]
  },
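  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small illustration with a hypothetical `month` column: because it is stored as integers, dtype-based detection misses it, so we record our domain knowledge explicitly, for example by casting the column to the pandas `category` dtype.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical frame: 'month' is categorical in meaning but stored as integers\n",
    "df = pd.DataFrame({\n",
    "    'month': [1, 5, 12, 5],\n",
    "    # dtype=object mirrors how the string columns in this tutorial's data are stored\n",
    "    'city': pd.Series(['Denver', 'Boulder', 'Denver', 'Fort Collins'], dtype=object),\n",
    "    'temperature': [1.5, 14.2, -2.0, 15.1],\n",
    "})\n",
    "\n",
    "# Only the string column is detected\n",
    "print(list(df.select_dtypes(include=['object']).columns))  # ['city']\n",
    "\n",
    "# Mark 'month' as categorical by hand, then include that dtype in the search\n",
    "df['month'] = df['month'].astype('category')\n",
    "print(list(df.select_dtypes(include=['object', 'category']).columns))  # ['month', 'city']\n",
    "```"
   ]
  },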
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Elevation <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1978</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Elevation \u001b[1;36m1978\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Aspect <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">361</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Aspect \u001b[1;36m361\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Slope <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">67</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Slope \u001b[1;36m67\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Horizontal_Distance_To_Hydrology <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">551</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Horizontal_Distance_To_Hydrology \u001b[1;36m551\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Vertical_Distance_To_Hydrology <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">700</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Vertical_Distance_To_Hydrology \u001b[1;36m700\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Horizontal_Distance_To_Roadways <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5785</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Horizontal_Distance_To_Roadways \u001b[1;36m5785\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Hillshade_9am <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">207</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Hillshade_9am \u001b[1;36m207\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Hillshade_Noon <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">185</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Hillshade_Noon \u001b[1;36m185\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Hillshade_3pm <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">255</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Hillshade_3pm \u001b[1;36m255\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Horizontal_Distance_To_Fire_Points <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5827</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Horizontal_Distance_To_Fire_Points \u001b[1;36m5827\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Cover_Type <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Cover_Type \u001b[1;36m7\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Wilderness_Area <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Wilderness_Area \u001b[1;36m4\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Soil_Type <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Soil_Type \u001b[1;36m40\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Another way is to use the unique values in each column.\n",
    "for col in data.columns:\n",
    "    print(col, len(data[col].unique()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But this is not reliable either. For example, we have a column called `Soil_Type` and it has 40 unique values. How do we decide if it is a categorical column or a continuous column? We need to use our judgement here as well.\n",
    "\n",
    "And reading the data description, understanding the domain, and using our judgement is the best way to decide if a column is categorical or continuous. \n",
    "\n",
    "Here we will consider `Wilderness_Area` and  `Soil_Type` as categorical features. We know `Cover_Type` is the target column and that makes the rest of the columns continuous features."
   ]
  },
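  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to make that judgement call a bit more systematic, unique-value counts can at least produce a shortlist for manual review. The helper below is a hypothetical heuristic of our own, not a PyTorch Tabular feature, and its threshold of 50 distinct values is an arbitrary assumption.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def low_cardinality_candidates(df, max_unique=50, exclude=()):\n",
    "    # Shortlist columns with few distinct values as *possible* categoricals;\n",
    "    # the final decision still relies on the data description and domain knowledge.\n",
    "    return [\n",
    "        col for col in df.columns\n",
    "        if col not in exclude and df[col].nunique() <= max_unique\n",
    "    ]\n",
    "\n",
    "# Hypothetical frame loosely modelled on the Covertype columns\n",
    "toy = pd.DataFrame({\n",
    "    'Elevation': range(2000, 2100),                        # 100 distinct values\n",
    "    'Soil_Type': [f'type_{i % 40}' for i in range(100)],  # 40 distinct values\n",
    "    'Cover_Type': [i % 7 + 1 for i in range(100)],         # the target, 7 classes\n",
    "})\n",
    "print(low_cardinality_candidates(toy, exclude=('Cover_Type',)))  # ['Soil_Type']\n",
    "```\n",
    "\n",
    "On the real data the outcome depends heavily on the threshold: raise it to 70 and the continuous `Slope` column (67 unique values, per the counts above) would be flagged too, which is exactly why the final call stays manual."
   ]
  },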
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Data Shape: <span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">581012</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">13</span><span style=\"font-weight: bold\">)</span> | # of cat cols: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> | # of num cols: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Data Shape: \u001b[1m(\u001b[0m\u001b[1;36m581012\u001b[0m, \u001b[1;36m13\u001b[0m\u001b[1m)\u001b[0m | # of cat cols: \u001b[1;36m2\u001b[0m | # of num cols: \u001b[1;36m10\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\"> Features: [</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Elevation'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Aspect'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Slope'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Horizontal_Distance_To_Hydrology'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Vertical_Distance_To_Hydrology'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span>\n",
       "<span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Horizontal_Distance_To_Roadways'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Hillshade_9am'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Hillshade_Noon'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Hillshade_3pm'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span>\n",
       "<span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Horizontal_Distance_To_Fire_Points'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Wilderness_Area'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">, </span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">'Soil_Type'</span><span style=\"color: #005fff; text-decoration-color: #005fff; font-weight: bold\">]</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;38;5;27m Features: \u001b[0m\u001b[1;38;5;27m[\u001b[0m\u001b[1;38;5;27m'Elevation'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Aspect'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Slope'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Horizontal_Distance_To_Hydrology'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Vertical_Distance_To_Hydrology'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\n",
       "\u001b[1;38;5;27m'Horizontal_Distance_To_Roadways'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Hillshade_9am'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Hillshade_Noon'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Hillshade_3pm'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\n",
       "\u001b[1;38;5;27m'Horizontal_Distance_To_Fire_Points'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Wilderness_Area'\u001b[0m\u001b[1;38;5;27m, \u001b[0m\u001b[1;38;5;27m'Soil_Type'\u001b[0m\u001b[1;38;5;27m]\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #5f00af; text-decoration-color: #5f00af; font-weight: bold\">Target: Cover_Type</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;38;5;55mTarget: Cover_Type\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# This separation have already been done for you while loading this particular dataset from `PyTorch Tabular`. Let's load the dataset in the right way.\n",
    "data, cat_col_names, num_col_names, target_col = load_covertype_dataset()\n",
    "# Let's also print out a few details\n",
    "print(f\"Data Shape: {data.shape} | # of cat cols: {len(cat_col_names)} | # of num cols: {len(num_col_names)}\")\n",
    "print(f\"[bold dodger_blue2] Features: {num_col_names + cat_col_names}[/bold dodger_blue2]\")\n",
    "print(f\"[bold purple4]Target: {target_col}[/bold purple4]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #b3ebf3; padding: 10px; color: #136876\">\n",
    "<b>Note</b> </br> Supervised Learning reduces to finding a function that maps inputs to outputs. The inputs are called features and the outputs are called targets. The features can be continuous or categorical. The targets can be continuous or categorical. In classification, the targets are categorical. In regression, the targets are continuous.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Elevation                             <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Aspect                                <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Slope                                 <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Horizontal_Distance_To_Hydrology      <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Vertical_Distance_To_Hydrology        <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Horizontal_Distance_To_Roadways       <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Hillshade_9am                         <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Hillshade_Noon                        <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Hillshade_3pm                         <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Horizontal_Distance_To_Fire_Points    <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Cover_Type                            <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Wilderness_Area                       <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "Soil_Type                             <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>\n",
       "dtype: int64\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Elevation                             \u001b[1;36m0\u001b[0m\n",
       "Aspect                                \u001b[1;36m0\u001b[0m\n",
       "Slope                                 \u001b[1;36m0\u001b[0m\n",
       "Horizontal_Distance_To_Hydrology      \u001b[1;36m0\u001b[0m\n",
       "Vertical_Distance_To_Hydrology        \u001b[1;36m0\u001b[0m\n",
       "Horizontal_Distance_To_Roadways       \u001b[1;36m0\u001b[0m\n",
       "Hillshade_9am                         \u001b[1;36m0\u001b[0m\n",
       "Hillshade_Noon                        \u001b[1;36m0\u001b[0m\n",
       "Hillshade_3pm                         \u001b[1;36m0\u001b[0m\n",
       "Horizontal_Distance_To_Fire_Points    \u001b[1;36m0\u001b[0m\n",
       "Cover_Type                            \u001b[1;36m0\u001b[0m\n",
       "Wilderness_Area                       \u001b[1;36m0\u001b[0m\n",
       "Soil_Type                             \u001b[1;36m0\u001b[0m\n",
       "dtype: int64\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Let's also check the data for missing values\n",
    "print(data.isna().sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great news! There is no missing values in the dataset. If there were any missing values, we need to handle them. Kaggle has a good tutorial on how to handle missing values. You can find it [here](https://www.kaggle.com/alexisbcook/missing-values).\n",
    "\n",
    "<div style=\"background-color: #b3ebf3; padding: 10px; color: #136876\">\n",
    "<b>Note</b> </br> PyTorch Tabular can deal with mising values in categorical features natively, but missing values in numerical features need to be handled separately.\n",
    "</div>"
   ]
  },
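  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If there had been gaps in the numerical columns, one common option is median imputation. The sketch below uses scikit-learn's `SimpleImputer` on hypothetical toy frames; the median strategy is just one choice among many. The important detail is that the imputer is fitted on the training data only and then reused on held-out data, so no statistics leak from validation or test rows.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "# Hypothetical train/test frames with missing numerical values\n",
    "train_df = pd.DataFrame({'Elevation': [2500.0, np.nan, 2700.0, 2900.0],\n",
    "                         'Slope': [10.0, 12.0, np.nan, 20.0]})\n",
    "test_df = pd.DataFrame({'Elevation': [np.nan, 2600.0],\n",
    "                        'Slope': [15.0, np.nan]})\n",
    "num_cols = ['Elevation', 'Slope']\n",
    "\n",
    "# Learn per-column medians from the training data only\n",
    "imputer = SimpleImputer(strategy='median')\n",
    "train_df[num_cols] = imputer.fit_transform(train_df[num_cols])\n",
    "# Fill the held-out data with the *training* medians\n",
    "test_df[num_cols] = imputer.transform(test_df[num_cols])\n",
    "\n",
    "print(imputer.statistics_.tolist())  # [2700.0, 12.0]\n",
    "```\n",
    "\n",
    "In this tutorial, that would mean fitting the imputer after the train/validation/test split in the next step, not on the full dataset."
   ]
  },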
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Split the dataset into train and test\n",
    "\n",
    "Now, in all tabular problems, when we apply machine learning we need to have a training set, validation set and a test set. We will use the training set to train the model, validation set to make modelling decisions(like the hyperparameters, or kind of model to use etc.) and the test set to evaluate the final model. Since the dataset doesn't come with a test set, we will split the training set into training, validation and test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Train Shape: <span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">371847</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">13</span><span style=\"font-weight: bold\">)</span> | Val Shape: <span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">92962</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">13</span><span style=\"font-weight: bold\">)</span> | Test Shape: <span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">116203</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">13</span><span style=\"font-weight: bold\">)</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Train Shape: \u001b[1m(\u001b[0m\u001b[1;36m371847\u001b[0m, \u001b[1;36m13\u001b[0m\u001b[1m)\u001b[0m | Val Shape: \u001b[1m(\u001b[0m\u001b[1;36m92962\u001b[0m, \u001b[1;36m13\u001b[0m\u001b[1m)\u001b[0m | Test Shape: \u001b[1m(\u001b[0m\u001b[1;36m116203\u001b[0m, \u001b[1;36m13\u001b[0m\u001b[1m)\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "train, test = train_test_split(data, random_state=42, test_size=0.2)\n",
    "train, val = train_test_split(train, random_state=42, test_size=0.2)\n",
    "print(f\"Train Shape: {train.shape} | Val Shape: {val.shape} | Test Shape: {test.shape}\")"
   ]
  },
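  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed shapes follow from the rounding `train_test_split` applies: with a float `test_size` and no `train_size`, recent scikit-learn versions round the held-out part up (`ceil(test_size * n_samples)`) and give the remainder to the training part. The small helper below is our own and only mirrors that rule to reproduce the numbers. Splitting twice with `test_size=0.2` leaves roughly 64% of the rows for training, 16% for validation and 20% for testing.\n",
    "\n",
    "```python\n",
    "from math import ceil\n",
    "\n",
    "def split_sizes(n_samples, test_size):\n",
    "    # Mirrors scikit-learn's rule for a float test_size when train_size is None:\n",
    "    # the held-out part is rounded up and the remainder goes to training.\n",
    "    n_test = ceil(test_size * n_samples)\n",
    "    return n_samples - n_test, n_test\n",
    "\n",
    "n_train_val, n_test = split_sizes(581012, 0.2)  # first split: (train + val) vs test\n",
    "n_train, n_val = split_sizes(n_train_val, 0.2)  # second split: train vs val\n",
    "print(n_train, n_val, n_test)  # 371847 92962 116203\n",
    "```"
   ]
  },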
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Preprocess the dataset\n",
    "\n",
    "In a typical machine learning project, this is the most time consuming step where we create new features, clean the data, handle missing values, handle outliers, scale the data, encode categorical features and so on. \n",
    "\n",
    "In a scikit-learn based project, a pseudo code for this step would look like this:\n",
    "\n",
    "```python\n",
    "data = create_new_features(data)\n",
    "data = clean_data(data)\n",
    "data = handle_missing_values(data)\n",
    "data = handle_outliers(data)\n",
    "data, cat_encoder = encode_categorical_features(data)\n",
    "data, scaler = scale_data(data)\n",
    "X, y = split_features_target(data)\n",
    "```\n",
    "\n",
    "But one of the allures of deep learning is that we don't need to spend time on feature engineering. We can just use the raw data and let the model figure out the best features to use. But we still need to do some data preparation. And for that, PyTorch Tabular takes care of some of these needs:\n",
    "\n",
    "- Missing values in categorical features are handled natively\n",
    "- Categorical features are encoded automatically using embeddings\n",
    "- Continuous features are scaled automatically using StandardScaler\n",
    "- Date features like month, day, year are extracted automatically\n",
    "- Target transformation like log, power, quantile, box-cox can be enabled with a parameter. This will also handle the inverse tranformation automatically.\n",
    "- Continuous features can be transformed using box-cox, quantile normal etc. with a parameter\n",
    "\n",
    "While we have all these features, we can also choose to do any of these manually. For example, we can choose to encode categorical features using one hot encoding or target encoding and consider them as continuous features. We can also choose to scale the continuous features using MinMaxScaler or RobustScaler and turn off the automatic scaling.\n",
    "\n",
    "So, here, we won't be doing any of these. We will just use the data as is and let PyTorch Tabular handle the rest.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Define the Model\n",
    "\n",
    "Now, we will define the model. In a `scikit-learn` workflow, we would have done the following steps:\n",
    "\n",
    "```python\n",
    "model = SomeModel(**parameters)\n",
    "```\n",
    "\n",
    "This is where PyTorch Tabular is different because we need to define a few configs before we define the model. One of the reasons is that PyTorch Tabular model handles a lot of things under the hood. So, we need to tell the model what kind of data we are dealing with. We also need to define the training dynamics, along with the model parameters. The configs we need to define are:\n",
    "\n",
    "\n",
    "1. `DataConfig` - This is where we define the data related configs like the target column, categorical columns, continuous columns, date columns, categorical embedding dimensions, etc. But the good news is that most of these are optional. If we don't define them, PyTorch Tabular will try to infer them from the data or have thumb rules to handle them. The bare minimum we need to define is the target column name, continuous columns and categorical columns. Categorical columns are embedded by default, numerical columns scaled by default and date columns are extracted by default.\n",
    "\n",
    "2. `TrainerConfig` - This is where we define the training related configs like the batch size, number of epochs, early stopping, etc. Again, all of these are optional. If we don't define them, PyTorch Tabular will use some default values. By default, `PyTorch Tabular` runs with a batch size of 64, with early stopping with a patience of 3 epochs and checkpointing enabled. This means that the model will be saved at the end of every epoch and the best model will be saved. The model will stop training if the validation loss doesn't improve for 3 epochs. Although all of `TrainerConfig` is optional, it is infinitely customizable. And with the entire PyTorch Lightning `Trainer` exposed, either through explicit parameters in `TrainerConfig` or through a catch-all `trainer_kwargs` parameter in `TrainerConfig`.\n",
    "\n",
    "3. `OptimizerConfig` - This is where we define the optimizer related configs like the the kind of optimizer, weight decay, learning rate schedulers, etc. Again, all of these are optional. If we don't define them, PyTorch Tabular will use some default values. By default, `PyTorch Tabular` uses the `Adam` optimizer. It doesn't use any learning rate decay by default. Although all of `OptimizerConfig` is optional, it is also customizable.\n",
    "\n",
    "4. `ExperimentConfig` - This is where we define how to track the experiment for logging and reproducibility. By default, `PyTorch Tabular` uses `tensorboard` for logging. But we can also use `wandb`. We can also choose to not log anything (although not recommended) by not defining an `ExperimentConfig`.\n",
    "\n",
    "5. `<ModelSpecificConfig>` - This is where we define which model to use and the corresponding hyperparameters. In `PyTorch Tabular`, each of the implemented model has their own config class. For example, if we want to use `TabNet`, we need to define `TabNetConfig`. If we want to use `GANDALF`, we need to define `GANDALFConfig`, and so on. Each of these config classes have their own set of model specific hyperparameters, as well as some common parameters like the loss function, metrics, learning rate, etc. Again, all of these are optional. If we don't define them, PyTorch Tabular will use some default values. Learning RAte is set to 1e-3 by default. The loss function is set to `CrossEntropyLoss` for classification and `MSELoss` for regression. The metrics are set to `Accuracy` for classification and `MSE` for regression. And all the model specific hyperparameters are set to suggested default values in their respective papers, or some default values that work well in practice.\n",
    "\n",
    "\n",
    "Here, let's use GANDALF Model. We will define the configs as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "\n",
    "from pytorch_tabular.models import GANDALFConfig\n",
    "from pytorch_tabular.config import (\n",
    "    DataConfig,\n",
    "    OptimizerConfig,\n",
    "    TrainerConfig,\n",
    ")\n",
    "\n",
    "data_config = DataConfig(\n",
    "    target=[target_col],  # target should always be a list\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "model_config = GANDALFConfig(\n",
    "    task=\"classification\",\n",
    "    gflu_stages=6,\n",
    "    gflu_feature_init_sparsity=0.3,\n",
    "    gflu_dropout=0.0,\n",
    "    learning_rate=1e-3,\n",
    ")"
   ]
  },
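  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above keeps every `OptimizerConfig` default. As a sketch of how the optimizer could be customized, assuming the `optimizer`, `optimizer_params`, `lr_scheduler` and `lr_scheduler_params` fields of `OptimizerConfig` (verify them against the API reference for your installed version), weight decay and a learning rate scheduler might be added like this. The name `custom_optimizer_config` and all values are illustrative, not tuned settings.\n",
    "\n",
    "```python\n",
    "from pytorch_tabular.config import OptimizerConfig\n",
    "\n",
    "# Illustrative values only; field names are assumed from the OptimizerConfig API\n",
    "custom_optimizer_config = OptimizerConfig(\n",
    "    optimizer='AdamW',  # name of an optimizer class in torch.optim\n",
    "    optimizer_params={'weight_decay': 1e-5},  # keyword arguments for the optimizer\n",
    "    lr_scheduler='ReduceLROnPlateau',  # name of a class in torch.optim.lr_scheduler\n",
    "    lr_scheduler_params={'mode': 'min', 'factor': 0.5, 'patience': 3},\n",
    ")\n",
    "```\n",
    "\n",
    "Passing `custom_optimizer_config` as `optimizer_config` to `TabularModel` would then replace the defaults; `ReduceLROnPlateau` lowers the learning rate when the monitored metric stops improving."
   ]
  },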
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have defined all the configs, we can define the `TabularModel`. Apart from the configs, we can pass a few additional parameters, such as `verbose`, to control how much the model logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:30</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">992</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">140</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:30\u001b[0m,\u001b[1;36m992\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we passed `verbose=True`, the model has already logged that experiment tracking is turned off (we did not define an `ExperimentConfig`)."
   ]
  },
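  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tracking is off here because no `ExperimentConfig` was passed. A minimal sketch of turning it on, reusing the configs defined above and assuming that `ExperimentConfig` accepts `project_name`, `run_name` and `log_target`, and that `TabularModel` takes it through an `experiment_config` argument. The project and run names below are hypothetical placeholders.\n",
    "\n",
    "```python\n",
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.config import ExperimentConfig\n",
    "\n",
    "# Hypothetical project/run names; log_target='wandb' would need the wandb package\n",
    "experiment_config = ExperimentConfig(\n",
    "    project_name='covertype_tutorial',\n",
    "    run_name='gandalf_baseline',\n",
    "    log_target='tensorboard',\n",
    ")\n",
    "\n",
    "# Same configs as before, plus the experiment config (assumed keyword argument)\n",
    "tracked_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    experiment_config=experiment_config,\n",
    "    verbose=True,\n",
    ")\n",
    "```\n",
    "\n",
    "Nothing else in the workflow changes; `fit`, `predict` and `evaluate` are called the same way."
   ]
  },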
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Train the Model\n",
    "\n",
    "Now, we can train the model. In the `scikit-learn` workflow, we would have done the following:\n",
    "\n",
    "```python\n",
    "model.fit(X_train, y_train)\n",
    "```\n",
    "\n",
    "In PyTorch Tabular, there are two ways we can do this:\n",
    "- High-Level API - A `fit` method that is very similar to the scikit-learn API, but with many more parameters to control the training dynamics. This is the recommended way to train the model.\n",
    "- Low-Level API - A collection of methods (`prepare_dataloader`, `prepare_model`, and `train`) for advanced users who want more control over the training process.\n",
    "\n",
    "Let's stick to the high-level API in this introductory tutorial and use the `fit` method to train the model. The only compulsory parameter of `fit` is the `train` data. We can also pass `validation` data explicitly; if we don't, 20% of the training data is used for validation. In addition, `fit` accepts many other parameters, such as custom loss functions, metrics, and optimizers, which make the training process more customizable."
   ]
  },
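  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you would rather control the validation split yourself (for example, to keep it fixed across experiments) instead of relying on the default 20% holdout, a minimal pandas sketch follows. The random holdout and the helper name `split_validation` are assumptions for illustration; the library's internal split may differ in detail.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def split_validation(df: pd.DataFrame, frac: float = 0.2, seed: int = 42):\n",
    "    # Hypothetical helper: hold out a random `frac` of the rows as validation\n",
    "    val_df = df.sample(frac=frac, random_state=seed)\n",
    "    # The remaining rows (index not in the holdout) become the training set\n",
    "    train_df = df.drop(index=val_df.index)\n",
    "    return train_df, val_df\n",
    "\n",
    "\n",
    "# Toy frame standing in for the Covertype training data\n",
    "toy = pd.DataFrame({'elevation': range(10), 'cover_type': [1, 2] * 5})\n",
    "toy_train, toy_val = split_validation(toy)\n",
    "print(len(toy_train), len(toy_val))  # 8 2\n",
    "```\n",
    "\n",
    "The two frames could then be passed as `fit(train=..., validation=...)`, exactly as in the cell below."
   ]
  },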
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:31</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">059</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">524</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:31\u001b[0m,\u001b[1;36m059\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m524\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:31</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">500</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:499</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:31\u001b[0m,\u001b[1;36m500\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:499\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:32</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">358</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">574</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: GANDALFModel           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:32\u001b[0m,\u001b[1;36m358\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m574\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: GANDALFModel           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:32</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">591</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">340</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:32\u001b[0m,\u001b[1;36m591\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m340\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:32</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">839</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">630</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:32\u001b[0m,\u001b[1;36m839\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m630\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "888b2d805e274d468ccf956a7099da19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_steps=100` reached.\n",
      "Learning rate set to 0.02089296130854041\n",
      "Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_6d1c6109-882a-4b7f-939c-2d42ecd8ff06.ckpt\n",
      "Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_6d1c6109-882a-4b7f-939c-2d42ecd8ff06.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:37</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">498</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">643</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.02089296130854041</span>. For plot \n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:37\u001b[0m,\u001b[1;36m498\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m643\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.02089296130854041\u001b[0m. For plot \n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:39:37</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">500</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">652</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:39:37\u001b[0m,\u001b[1;36m500\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m652\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ GANDALFBackbone  │ 42.4 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer │    896 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ Sequential       │    252 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ GANDALFBackbone  │ 42.4 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │    896 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ Sequential       │    252 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 43.6 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 43.6 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 43.6 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 43.6 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d560c8740d604644bd93dcc194a698c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:41:25</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">635</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">663</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:41:25\u001b[0m,\u001b[1;36m635\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m663\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:41:25</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">636</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1487</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:41:25\u001b[0m,\u001b[1;36m636\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1487\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f1dbab82c10>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tabular_model.fit(train=train, validation=val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: Make Predictions on New Data\n",
    "\n",
    "Now that we have trained the model, we can make predictions on new data. In a `scikit-learn` workflow, we would have done the following:\n",
    "\n",
    "```python\n",
    "y_pred = model.predict(X_test)\n",
    "y_pred_proba = model.predict_proba(X_test)\n",
    "```\n",
    "\n",
    "In PyTorch Tabular, we can do something very similar. We use the `predict` method to make predictions on new data, and it returns them as a pandas DataFrame. For classification problems, the DataFrame contains one probability column per class and a final `prediction` column; the predicted class is the one with the highest probability (for a binary problem, this is equivalent to a 0.5 threshold). All we have to do is pass in a DataFrame that contains at least all the features that were used for training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1_probability</th>\n",
       "      <th>2_probability</th>\n",
       "      <th>3_probability</th>\n",
       "      <th>4_probability</th>\n",
       "      <th>5_probability</th>\n",
       "      <th>6_probability</th>\n",
       "      <th>7_probability</th>\n",
       "      <th>prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>250728</th>\n",
       "      <td>0.901409</td>\n",
       "      <td>0.001267</td>\n",
       "      <td>1.025811e-08</td>\n",
       "      <td>9.070018e-08</td>\n",
       "      <td>0.000040</td>\n",
       "      <td>3.519856e-08</td>\n",
       "      <td>9.728358e-02</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>246788</th>\n",
       "      <td>0.156802</td>\n",
       "      <td>0.843021</td>\n",
       "      <td>8.172029e-07</td>\n",
       "      <td>2.142834e-09</td>\n",
       "      <td>0.000171</td>\n",
       "      <td>2.749175e-07</td>\n",
       "      <td>4.734764e-06</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>407714</th>\n",
       "      <td>0.001035</td>\n",
       "      <td>0.969636</td>\n",
       "      <td>4.896594e-03</td>\n",
       "      <td>4.262948e-06</td>\n",
       "      <td>0.019907</td>\n",
       "      <td>4.521028e-03</td>\n",
       "      <td>8.038983e-07</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25713</th>\n",
       "      <td>0.289917</td>\n",
       "      <td>0.709881</td>\n",
       "      <td>1.039616e-05</td>\n",
       "      <td>5.966012e-08</td>\n",
       "      <td>0.000152</td>\n",
       "      <td>3.714674e-05</td>\n",
       "      <td>1.749813e-06</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21820</th>\n",
       "      <td>0.000729</td>\n",
       "      <td>0.870874</td>\n",
       "      <td>2.740137e-05</td>\n",
       "      <td>3.132881e-06</td>\n",
       "      <td>0.128357</td>\n",
       "      <td>9.515504e-06</td>\n",
       "      <td>9.939656e-09</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        1_probability  2_probability  3_probability  4_probability  \\\n",
       "250728       0.901409       0.001267   1.025811e-08   9.070018e-08   \n",
       "246788       0.156802       0.843021   8.172029e-07   2.142834e-09   \n",
       "407714       0.001035       0.969636   4.896594e-03   4.262948e-06   \n",
       "25713        0.289917       0.709881   1.039616e-05   5.966012e-08   \n",
       "21820        0.000729       0.870874   2.740137e-05   3.132881e-06   \n",
       "\n",
       "        5_probability  6_probability  7_probability  prediction  \n",
       "250728       0.000040   3.519856e-08   9.728358e-02           1  \n",
       "246788       0.000171   2.749175e-07   4.734764e-06           2  \n",
       "407714       0.019907   4.521028e-03   8.038983e-07           2  \n",
       "25713        0.000152   3.714674e-05   1.749813e-06           2  \n",
       "21820        0.128357   9.515504e-06   9.939656e-09           2  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df = tabular_model.predict(test)\n",
    "pred_df.head()"
   ]
  },
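  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the output above, the `prediction` column matches the class whose `<class>_probability` column is largest. A small pandas sketch of that relationship, assuming the `<class>_probability` column naming shown above and integer class labels; the helper `argmax_class` is hypothetical and the toy probabilities are made up (not outputs of the model).\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def argmax_class(pred_df: pd.DataFrame, suffix: str = '_probability') -> pd.Series:\n",
    "    # Pick the probability columns, e.g. '1_probability' ... '7_probability'\n",
    "    prob_cols = [c for c in pred_df.columns if c.endswith(suffix)]\n",
    "    # idxmax(axis=1) returns the name of the largest column in each row\n",
    "    best_col = pred_df[prob_cols].idxmax(axis=1)\n",
    "    # Strip the suffix to recover the class label, e.g. '2_probability' -> 2\n",
    "    return best_col.str[: -len(suffix)].astype(int)\n",
    "\n",
    "\n",
    "# Made-up probabilities for three classes, not outputs of the model above\n",
    "toy_pred = pd.DataFrame({\n",
    "    '1_probability': [0.90, 0.30, 0.10],\n",
    "    '2_probability': [0.05, 0.60, 0.20],\n",
    "    '3_probability': [0.05, 0.10, 0.70],\n",
    "})\n",
    "print(argmax_class(toy_pred).tolist())  # [1, 2, 3]\n",
    "```\n",
    "\n",
    "On the real output, `(argmax_class(pred_df) == pred_df['prediction']).all()` is one way to check this; with non-integer labels, the `astype(int)` step would need to change."
   ]
  },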
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Step 8: Evaluate the Model\n",
    "\n",
    "Now, we can evaluate the model. In the `scikit-learn` workflow, we would have done the following:\n",
    "\n",
    "```python\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "y_pred = model.predict(X_test)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "```\n",
    "\n",
    "In PyTorch Tabular, there are two ways we can do this:\n",
    "- Get the predictions on the test set and calculate the metrics manually\n",
    "- Use the `evaluate` method which will return the metrics (the same ones we have defined during training)\n",
    "\n",
    "We will use the second option here: the `evaluate` method evaluates the model on the test set and returns the computed metrics, such as `test_accuracy` and `test_loss`."
   ]
  },
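  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the first option boils down to comparing the `prediction` column with the true labels. A minimal pandas sketch, with toy labels standing in for `test[target_col]` and `pred_df['prediction']` (the values are made up for illustration):\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy stand-ins for test[target_col] and pred_df['prediction']\n",
    "y_true = pd.Series([1, 2, 2, 3, 1])\n",
    "y_pred = pd.Series([1, 2, 3, 3, 1])\n",
    "\n",
    "# Element-wise comparison gives booleans; their mean is the accuracy\n",
    "accuracy = (y_true == y_pred).mean()\n",
    "print(accuracy)  # 0.8\n",
    "```\n",
    "\n",
    "On the real data, the same idea would read `(pred_df['prediction'] == test[target_col]).mean()`, since `pred_df` appears to keep the index of `test`."
   ]
  },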
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa942695347d4009a33987646d1ab653",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.878411054611206     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.2998563051223755     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.878411054611206    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.2998563051223755    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = tabular_model.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'test_loss': 0.2998563051223755, 'test_accuracy': 0.878411054611206}]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result"
   ]
  },
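  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the output above shows, `evaluate` returns a list with one dictionary per test dataloader, and each metric name carries a `test_` prefix. To use the numbers in code, for example to log them or compare runs, a small helper such as the hypothetical `metrics_for_dataloader` below (not part of PyTorch Tabular) can pick out one dataloader's metrics and strip the prefix. The input values are copied from the printed `result`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def metrics_for_dataloader(results, index=0, prefix=\"test_\"):\n",
    "    \"\"\"Return one dataloader's metrics with `prefix` removed from the metric names.\"\"\"\n",
    "    # str.removeprefix needs Python 3.9+; names without the prefix are left unchanged\n",
    "    return {name.removeprefix(prefix): value for name, value in results[index].items()}\n",
    "\n",
    "\n",
    "# Values copied from the `result` printed above (a single test dataloader);\n",
    "# in this notebook you could pass `result` itself instead\n",
    "example_result = [{\"test_loss\": 0.2998563051223755, \"test_accuracy\": 0.878411054611206}]\n",
    "metrics = metrics_for_dataloader(example_result)\n",
    "print(f\"Accuracy: {metrics['accuracy']:.4f}\")\n",
    "print(f\"Loss:     {metrics['loss']:.4f}\")"
   ]
  },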
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Step 9: Saving and Loading the Model\n",
    "\n",
    "After the model is trained, we can save the model and load it later to make predictions on new data. In a `scikit-learn` workflow, we would have done the following:\n",
    "\n",
    "```python\n",
    "joblib.dump(model, \"model.joblib\")\n",
    "model = joblib.load(\"model.joblib\")\n",
    "```\n",
    "\n",
    "In PyTorch Tabular, we can do something very similar. We can use the `save_model` method to save the model. This method saves everything required to make predictions on new data. By default it also saves the datamodule, which contains the training data, validation data, and test data as well. But we can choose to not save the datamodule by setting `inference_only=True`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:43:51</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">268</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1531</span><span style=\"font-weight: bold\">}</span> - WARNING - Directory is not empty. Overwriting the \n",
       "contents.                                                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:43:51\u001b[0m,\u001b[1;36m268\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1531\u001b[0m\u001b[1m}\u001b[0m - WARNING - Directory is not empty. Overwriting the \n",
       "contents.                                                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tabular_model.save_model(\"examples/basic\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can load the saved model using the `load_model` method. This method returns the model and the datamodule. We can use the model to make predictions on new data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:43:51</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">948</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">165</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:43:51\u001b[0m,\u001b[1;36m948\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m165\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">01</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">04:43:51</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">953</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">340</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m01\u001b[0m-\u001b[1;36m07\u001b[0m \u001b[1;92m04:43:51\u001b[0m,\u001b[1;36m953\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m340\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "loaded_model = TabularModel.load_model(\"examples/basic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f28c6a1d89674711b4a09467fcfb9291",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.878411054611206     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.2998563051223755     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.878411054611206    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.2998563051223755    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Let's check if we get the same result on test data using the loaded model\n",
    "result = loaded_model.evaluate(test)"
   ]
  },
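  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both tables report the same test accuracy and loss, so the saved model was restored correctly. To make this check explicit instead of comparing the tables by eye, you could compare the two results in code. The hypothetical helper `metrics_match` below (not part of PyTorch Tabular) uses `math.isclose` with a small relative tolerance, since tiny floating-point differences are possible in general, for example on different hardware. Because the cell above reassigns `result`, in practice you would keep the original model's metrics in a separate variable before loading; here the values are copied from the two tables above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def metrics_match(before, after, rel_tol=1e-6):\n",
    "    \"\"\"True if two `evaluate` results have the same metric names and values within `rel_tol`.\"\"\"\n",
    "    if len(before) != len(after):\n",
    "        return False\n",
    "    # Compare dataloader by dataloader\n",
    "    for m_before, m_after in zip(before, after):\n",
    "        if m_before.keys() != m_after.keys():\n",
    "            return False\n",
    "        if not all(math.isclose(m_before[k], m_after[k], rel_tol=rel_tol) for k in m_before):\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "# Values copied from the two evaluation tables above\n",
    "original_result = [{\"test_loss\": 0.2998563051223755, \"test_accuracy\": 0.878411054611206}]\n",
    "reloaded_result = [{\"test_loss\": 0.2998563051223755, \"test_accuracy\": 0.878411054611206}]\n",
    "print(metrics_match(original_result, reloaded_result))"
   ]
  },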
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #C8E6C9; padding: 10px; color: #1b7678\">\n",
    "<b>Congrats!</b>: You have trained a SOTA deep learning model data. Things you can try: </br>\n",
    "<ol>\n",
    "<li>Check out the <a src=https://pytorch-tabular.readthedocs.io/en/latest/>PyTorch Tabular Documentation</a> to learn more about the library</li>\n",
    "<li>Use alternate models like TabNet, CategoryEmbedding, etc.</li>\n",
    "<li>Use different datasets and try out the workflow.</li>\n",
    "<li>Check out other tutorials and how-to guides in the documentation.</li>\n",
    "</ol>\n",
    "Now try to use these features in your own projects and Kaggle competitions. If you have any questions, please feel free to ask them in the <a src=https://github.com/manujosephv/pytorch_tabular/discussions>GitHub Discussions</a>\n",
    "</div>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "c1644b008eb6c88a0ca3600445c5d81cce1a68be89616a3704b32e9da15a977e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
